{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use k-means clustering with an IoU-based distance to derive the anchor sizes and ratios that give the highest average IoU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the core methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import core packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### IoU calculation and k-means clustering methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iou(box, clusters):\n",
    "    \"\"\"\n",
    "    Intersection over Union (IoU) of one box against each of the k clusters.\n",
    "\n",
    "    The box and the clusters are compared as (width, height) pairs anchored at the\n",
    "    origin, so the overlap extent is the element-wise minimum of the two sizes.\n",
    "\n",
    "    :param box: tuple or array holding the width and height of the box\n",
    "    :param clusters: numpy array of shape (k, 2) where k is the number of clusters\n",
    "    :return: numpy array of shape (k,) with the IoU between the box and each cluster\n",
    "    :raises ValueError: if any overlap width or overlap height is exactly zero\n",
    "    \"\"\"\n",
    "    overlap_w = np.minimum(clusters[:, 0], box[0])\n",
    "    overlap_h = np.minimum(clusters[:, 1], box[1])\n",
    "    # a zero overlap extent means the box or one of the clusters is degenerate\n",
    "    if np.any(overlap_w == 0) or np.any(overlap_h == 0):\n",
    "        raise ValueError(\"Box has no area\")\n",
    "\n",
    "    overlap_area = overlap_w * overlap_h\n",
    "    union_area = box[0] * box[1] + clusters[:, 0] * clusters[:, 1] - overlap_area\n",
    "    return overlap_area / union_area\n",
    "\n",
    "def avg_iou(boxes, clusters):\n",
    "    \"\"\"\n",
    "    Mean, over all boxes, of the best IoU each box reaches against any of the k clusters.\n",
    "    :param boxes: numpy array of shape (r, 2), where r is the number of rows\n",
    "    :param clusters: numpy array of shape (k, 2) where k is the number of clusters\n",
    "    :return: average of the per-box maximum IoU as a single float\n",
    "    \"\"\"\n",
    "    best_per_box = []\n",
    "    for idx in range(boxes.shape[0]):\n",
    "        # keep only the IoU with the best-matching cluster for this box\n",
    "        best_per_box.append(np.max(iou(boxes[idx], clusters)))\n",
    "    return np.mean(best_per_box)\n",
    "\n",
    "\n",
    "def translate_boxes(boxes):\n",
    "    \"\"\"\n",
    "    Moves every box to the origin so that only its extent is left.\n",
    "\n",
    "    Columns 2 and 3 are replaced by the absolute differences |col2 - col0| and\n",
    "    |col3 - col1|, then columns 0 and 1 are dropped. The input array is not modified.\n",
    "\n",
    "    :param boxes: numpy array of shape (r, 4)\n",
    "    :return: numpy array of shape (r, 2)\n",
    "    \"\"\"\n",
    "    shifted = boxes.copy()\n",
    "    shifted[:, 2] = np.abs(shifted[:, 2] - shifted[:, 0])\n",
    "    shifted[:, 3] = np.abs(shifted[:, 3] - shifted[:, 1])\n",
    "    return np.delete(shifted, [0, 1], axis=1)\n",
    "\n",
    "\n",
    "def kmeans(boxes, k, dist=np.median, seed=None):\n",
    "    \"\"\"\n",
    "    Calculates k-means clustering with the Intersection over Union (IoU) metric.\n",
    "\n",
    "    The distance between a box and a cluster is 1 - IoU. Iteration stops as soon as\n",
    "    the box-to-cluster assignment no longer changes.\n",
    "\n",
    "    :param boxes: numpy array of shape (r, 2), where r is the number of rows\n",
    "    :param k: number of clusters\n",
    "    :param dist: function called as dist(members, axis=0) to reduce the boxes of a\n",
    "        cluster to its new centre\n",
    "    :param seed: value forwarded to np.random.seed; None (default) re-seeds from fresh\n",
    "        entropy, pass an int to make the random initialisation reproducible\n",
    "    :return: numpy array of shape (k, 2)\n",
    "    :raises ValueError: if k is larger than the number of boxes (sampling without\n",
    "        replacement), or if iou() meets a degenerate box\n",
    "    \"\"\"\n",
    "    rows = boxes.shape[0]\n",
    "\n",
    "    distances = np.empty((rows, k))\n",
    "    # -1 is never a valid cluster index, so the first pass always updates the centres\n",
    "    # (an all-zero start would stop immediately if every box first landed in cluster 0)\n",
    "    last_clusters = np.full((rows,), -1)\n",
    "\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    # the Forgy method will fail if the whole array contains the same rows\n",
    "    # NOTE: clusters keeps the dtype of boxes, so with integer boxes the centres\n",
    "    # computed below are cast back to integers when stored\n",
    "    clusters = boxes[np.random.choice(rows, k, replace=False)]\n",
    "\n",
    "    while True:\n",
    "        for row in range(rows):\n",
    "            distances[row] = 1 - iou(boxes[row], clusters)\n",
    "\n",
    "        nearest_clusters = np.argmin(distances, axis=1)\n",
    "\n",
    "        if (last_clusters == nearest_clusters).all():\n",
    "            break\n",
    "\n",
    "        for cluster in range(k):\n",
    "            members = boxes[nearest_clusters == cluster]\n",
    "            # a cluster without members would make dist() return NaN; keep its old centre\n",
    "            if members.shape[0] > 0:\n",
    "                clusters[cluster] = dist(members, axis=0)\n",
    "\n",
    "        last_clusters = nearest_clusters\n",
    "\n",
    "    return clusters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import xml.etree.ElementTree as ET\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the function that loads the XML annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset(path, show_img, category_id):\n",
    "    \"\"\"\n",
    "    Collect the size of every annotated bounding box found in a directory of XML files.\n",
    "\n",
    "    Each ``object`` element must provide ``bndbox/x0``, ``bndbox/y0``, ``bndbox/x1`` and\n",
    "    ``bndbox/y1`` as integers (absolute values, not divided by the image size).\n",
    "    Boxes are filtered by area according to ``category_id``:\n",
    "      's': 0 < area <= 32*32\n",
    "      'm': 32*32 < area <= 96*96\n",
    "      'l': area > 96*96\n",
    "    Any other value keeps every box.\n",
    "\n",
    "    :param path: directory that contains the ``*.xml`` annotation files\n",
    "    :param show_img: if True, load the matching image (annotation path with\n",
    "        \"annotations\" replaced by \"images\" and \".xml\" by \".jpg\"), draw the kept\n",
    "        boxes on it and block until a key is pressed\n",
    "    :param category_id: 's', 'm' or 'l' to pick a size bucket, anything else for all boxes\n",
    "    :return: numpy array with one [width, height] row per kept box\n",
    "    :raises FileNotFoundError: if ``show_img`` is True and the image cannot be read\n",
    "    \"\"\"\n",
    "    gt_small, gt_mid = 32 * 32, 96 * 96\n",
    "    # (exclusive lower bound, inclusive upper bound) on the box area of each bucket\n",
    "    area_ranges = {\n",
    "        's': (0, gt_small),\n",
    "        'm': (gt_small, gt_mid),\n",
    "        'l': (gt_mid, float(\"inf\")),\n",
    "    }\n",
    "    area_range = area_ranges.get(category_id) if isinstance(category_id, str) else None\n",
    "\n",
    "    dataset = []\n",
    "    paths = [p.replace(\"\\\\\", '/') for p in glob.glob(\"{}/*.xml\".format(path))]\n",
    "    print(\"Get %d xmls\" % len(paths))\n",
    "    if area_range is None:\n",
    "        # reported once here instead of once per annotated object\n",
    "        print(\"no category detected. Will use all possible bboxs\")\n",
    "\n",
    "    for xml_file in paths:\n",
    "        tree = ET.parse(xml_file)\n",
    "\n",
    "        # the image is only needed for display, so it is not read otherwise\n",
    "        image = None\n",
    "        if show_img:\n",
    "            img_file = xml_file.replace(\"annotations\",\"images\").replace(\".xml\", \".jpg\")\n",
    "            image = cv2.imread(img_file)\n",
    "            if image is None:\n",
    "                # cv2.imread signals an unreadable file by returning None\n",
    "                raise FileNotFoundError(\"Cannot read image: {}\".format(img_file))\n",
    "\n",
    "        for obj in tree.iter(\"object\"):\n",
    "            xmin = int(obj.findtext(\"bndbox/x0\"))\n",
    "            ymin = int(obj.findtext(\"bndbox/y0\"))\n",
    "            xmax = int(obj.findtext(\"bndbox/x1\"))\n",
    "            ymax = int(obj.findtext(\"bndbox/y1\"))\n",
    "            box_w, box_h = xmax - xmin, ymax - ymin\n",
    "\n",
    "            if area_range is not None:\n",
    "                lower, upper = area_range\n",
    "                if not lower < box_w * box_h <= upper:\n",
    "                    continue\n",
    "\n",
    "            dataset.append([box_w, box_h])\n",
    "            if image is not None:\n",
    "                # outline the boxes that were actually kept\n",
    "                image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (255,255,0), 1, cv2.LINE_AA)\n",
    "\n",
    "        if image is not None:\n",
    "            cv2.imshow(\"loading image\", image)\n",
    "            cv2.waitKey(0)\n",
    "            cv2.destroyAllWindows()\n",
    "\n",
    "    return np.array(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_anchor_para(bboxs, anchor_base_scale=4, anchor_stride=8):\n",
    "    \"\"\"\n",
    "    Turn clustered boxes (e.g. the k-means output) into anchor scales and anchor ratios.\n",
    "\n",
    "    For every (first, second) pair in bboxs the scale is\n",
    "    first / (anchor_base_scale * anchor_stride) and the ratio is the tuple\n",
    "    (1, second / first). Boxes produced by load_dataset are stored as [width, height],\n",
    "    so for them the scale is derived from the width and the ratio is height / width.\n",
    "\n",
    "    anchor_base_scale and anchor_stride refer to the first feature map and depend on the\n",
    "    network configuration; the default parameters should work for a Resnet50 backbone.\n",
    "\n",
    "    :return: (list of scales, list of (1, ratio) tuples), both in the order of bboxs\n",
    "    \"\"\"\n",
    "    base_factor = anchor_base_scale * anchor_stride\n",
    "    pairs = [(first, second) for first, second in bboxs]\n",
    "    scales = [first * 1.0 / base_factor for first, _ in pairs]\n",
    "    ratios = [(1, second * 1.0 / first) for first, second in pairs]\n",
    "    return scales, ratios\n",
    "         "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Relative path of the annotation directory that is passed to load_dataset\n",
    "ANNOTATIONS_PATH = os.path.join(\n",
    "    \"crop_data\",\n",
    "    \"train\",\n",
    "    \"annotations\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Number of clusters = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Get 62023 xmls\n"
     ]
    }
   ],
   "source": [
    "# Number of anchors to fit, then read the box sizes of the 'l' bucket into memory\n",
    "CLUSTERS = 3\n",
    "data = load_dataset(path=ANNOTATIONS_PATH, show_img=False, category_id='l')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 72.11%\n",
      "Boxes:\n",
      " [[101 133]\n",
      " [209 173]\n",
      " [153  86]]\n",
      "computed paras:  ([3.15625, 6.53125, 4.78125], [(1, 1.316831683168317), (1, 0.8277511961722488), (1, 0.5620915032679739)])\n"
     ]
    }
   ],
   "source": [
    "# Cluster the loaded box sizes and report how well the centres cover them\n",
    "out = kmeans(data, k=CLUSTERS)\n",
    "mean_iou_percent = avg_iou(data, out) * 100\n",
    "print(\"Accuracy: {:.2f}%\".format(mean_iou_percent))\n",
    "print(\"Boxes:\\n {}\".format(out))\n",
    "\n",
    "# Convert the cluster centres into anchor scales and ratios\n",
    "anchor_params = compute_anchor_para(out)\n",
    "print(\"computed paras: \", anchor_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Number of clusters = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Get 62023 xmls\n",
      "Accuracy: 76.53%\n",
      "Boxes:\n",
      " [[195 113]\n",
      " [146  81]\n",
      " [ 97 124]\n",
      " [129 178]\n",
      " [250 206]]\n",
      "computed paras:  ([6.09375, 4.5625, 3.03125, 4.03125, 7.8125], [(1, 0.5794871794871795), (1, 0.5547945205479452), (1, 1.2783505154639174), (1, 1.37984496124031), (1, 0.824)])\n"
     ]
    }
   ],
   "source": [
    "# Repeat the experiment with five anchors on the 'l' bucket\n",
    "CLUSTERS = 5\n",
    "data = load_dataset(path=ANNOTATIONS_PATH, show_img=False, category_id='l')\n",
    "out = kmeans(data, k=CLUSTERS)\n",
    "mean_iou_percent = avg_iou(data, out) * 100\n",
    "print(\"Accuracy: {:.2f}%\".format(mean_iou_percent))\n",
    "print(\"Boxes:\\n {}\".format(out))\n",
    "\n",
    "anchor_params = compute_anchor_para(out)\n",
    "print(\"computed paras: \", anchor_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Number of clusters = 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Get 62023 xmls\n",
      "Accuracy: 79.64%\n",
      "Boxes:\n",
      " [[127 190]\n",
      " [154  77]\n",
      " [208 108]\n",
      " [ 86 139]\n",
      " [144 131]\n",
      " [112 101]\n",
      " [252 206]]\n",
      "computed paras:  ([3.96875, 4.8125, 6.5, 2.6875, 4.5, 3.5, 7.875], [(1, 1.4960629921259843), (1, 0.5), (1, 0.5192307692307693), (1, 1.6162790697674418), (1, 0.9097222222222222), (1, 0.9017857142857143), (1, 0.8174603174603174)])\n"
     ]
    }
   ],
   "source": [
    "# Repeat the experiment with seven anchors on the 'l' bucket\n",
    "CLUSTERS = 7\n",
    "data = load_dataset(path=ANNOTATIONS_PATH, show_img=False, category_id='l')\n",
    "out = kmeans(data, k=CLUSTERS)\n",
    "mean_iou_percent = avg_iou(data, out) * 100\n",
    "print(\"Accuracy: {:.2f}%\".format(mean_iou_percent))\n",
    "print(\"Boxes:\\n {}\".format(out))\n",
    "\n",
    "anchor_params = compute_anchor_para(out)\n",
    "print(\"computed paras: \", anchor_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Number of clusters = 9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Get 62023 xmls\n"
     ]
    }
   ],
   "source": [
    "# Repeat the experiment with nine anchors on the 'l' bucket\n",
    "CLUSTERS = 9\n",
    "data = load_dataset(path=ANNOTATIONS_PATH, show_img=False, category_id='l')\n",
    "out = kmeans(data, k=CLUSTERS)\n",
    "mean_iou_percent = avg_iou(data, out) * 100\n",
    "print(\"Accuracy: {:.2f}%\".format(mean_iou_percent))\n",
    "print(\"Boxes:\\n {}\".format(out))\n",
    "\n",
    "anchor_params = compute_anchor_para(out)\n",
    "print(\"computed paras: \", anchor_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
